# -*- coding:utf-8 -*-

import os

import numpy as np
from base.main import train, predict
from tensorflow.examples.tutorials.mnist import input_data
from utils.io_utils import convert_abspath
import config.glob.config as global_cfg

"""
LeNet5 模型测试
"""

# Absolute paths of the YAML configs consumed by train() and predict().
# NOTE(review): the ``lenet5_`` prefix does not match the single_layer.yml
# files these actually point at; names kept unchanged for compatibility.
lenet5_train_config, lenet5_predict_config = (
    convert_abspath('config/core/train/single_layer.yml'),
    convert_abspath('config/core/predict/single_layer.yml'),
)


def single_layer_dnn_train():
    """
    Train the single-layer DNN from its training YAML config.

    Delegates entirely to ``train`` with the path stored in
    ``lenet5_train_config`` (config/core/train/single_layer.yml).

    :return: None
    """
    config_path = lenet5_train_config
    train(config_path)


def single_layer_dnn_predict(num_samples=2):
    """
    Run the single-layer DNN on a few MNIST test images and print the result.

    Draws ``num_samples`` examples from the MNIST *test* split (labels
    one-hot encoded) and runs ``predict`` with the prediction YAML config
    (``lenet5_predict_config``). It then prints the true class indices next
    to the predicted ones, taking the argmax along axis 1 of each.

    :param num_samples: number of test examples to draw (default 2, matching
        the previous hard-coded behaviour).
    :return: None
    """
    # os.path.join instead of hard-coded '\\' separators, which produced an
    # invalid path on non-Windows systems.
    mnist_dir = os.path.join(global_cfg.DATASET_ROOT, 'mnist')
    mnist_test = input_data.read_data_sets(mnist_dir, one_hot=True).test
    x, y = mnist_test.next_batch(num_samples)
    # Ensure a 2-D batch of 784-feature rows as expected by the model input.
    x = np.reshape(x, [-1, 784])

    y_pred = predict(lenet5_predict_config, x)
    # '\033[35m' switches the terminal to magenta; '\033[0m' resets it so
    # subsequent output is not left colored.
    print('\033[35my:{}, y_pred:{}\033[0m'.format(
        np.argmax(y, axis=1), np.argmax(y_pred, axis=1)))


if __name__ == '__main__':
    import argparse

    # Select the mode on the command line instead of (un)commenting calls.
    # With no argument the script trains, exactly as it did before.
    parser = argparse.ArgumentParser(description='Single-layer DNN train / predict')
    parser.add_argument('mode', nargs='?', choices=('train', 'predict'),
                        default='train', help="'train' (default) or 'predict'")
    args = parser.parse_args()

    if args.mode == 'predict':
        single_layer_dnn_predict()
    else:
        single_layer_dnn_train()
